import sys
import os
import matplotlib.pyplot as plt

# Plot few-shot accuracy and loss curves for every ``*.log`` file in a directory.
#
# Usage: python <this script> LOG_DIR
#
# For each ``NAME.log`` in LOG_DIR, writes ``NAME_Acc.png`` and ``NAME_Loss.png``
# into LOG_DIR, with one point per "####Few Shot" line found in that log.

FEWSHOT_MARKER = "####Few Shot"
LOG_SUFFIX = ".log"


def parse_fewshot_lines(lines):
    """Extract few-shot (loss, accuracy) values from an iterable of log lines.

    A line is used only if it contains ``"####Few Shot"``. For such a line,
    the text between the first and second ``=`` must be a ``loss/acc`` pair,
    e.g. ``"####Few Shot loss/acc = 1.2/0.5"``.

    Args:
        lines: Any iterable of strings, such as an open text file.

    Returns:
        A ``(losses, accs)`` tuple of float lists, in line order.

    Raises:
        IndexError: If a marker line contains no ``=``.
        ValueError: If the ``=`` payload is not exactly two ``/``-separated
            floats.
    """
    losses = []
    accs = []
    for line in lines:
        if FEWSHOT_MARKER not in line:
            continue
        loss, acc = line.split("=")[1].split("/")
        losses.append(float(loss))
        accs.append(float(acc))
    return losses, accs


def output_stem(filename):
    """Return *filename* with one trailing ``.log`` extension removed.

    Unlike ``str.strip('.log')``, this removes the exact suffix rather than
    any run of the characters ``.``, ``l``, ``o``, ``g`` at either end.
    Names without a ``.log`` extension are returned unchanged.
    """
    stem, ext = os.path.splitext(filename)
    return stem if ext == LOG_SUFFIX else filename


def save_curve(steps, values, label, path):
    """Plot *values* against *steps* with a legend entry *label*, save to *path*.

    A fresh figure is created and always closed, so repeated calls do not
    share state or accumulate open figures.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("steps")
        ax.plot(steps, values, label=label)
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)


def main(argv=None):
    """Write accuracy and loss plots for each ``.log`` file in ``argv[1]``.

    Args:
        argv: Argument vector; defaults to ``sys.argv``.

    Raises:
        SystemExit: If no log directory argument is supplied.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: python <script> LOG_DIR")
    log_dir = argv[1]

    # Use an exact suffix match (so "x.log.bak" is skipped) and ignore directories.
    log_files = sorted(
        fname
        for fname in os.listdir(log_dir)
        if fname.endswith(LOG_SUFFIX)
        and os.path.isfile(os.path.join(log_dir, fname))
    )

    for filename in log_files:
        log_path = os.path.join(log_dir, filename)
        # Only ASCII markers and numbers are parsed, so undecodable bytes
        # elsewhere in the log are replaced rather than aborting the run.
        with open(log_path, "r", encoding="utf-8", errors="replace") as fr:
            # Parsed per file, so each plot holds only this file's data.
            losses, accs = parse_fewshot_lines(fr)

        steps = list(range(len(accs)))
        stem = output_stem(filename)
        save_curve(steps, accs, "accuracy", os.path.join(log_dir, f"{stem}_Acc.png"))
        save_curve(steps, losses, "loss", os.path.join(log_dir, f"{stem}_Loss.png"))


if __name__ == "__main__":
    main()